import cv2
import mediapipe as mp
import numpy as np
import matplotlib.pyplot as plt
import os
from scipy.signal import savgol_filter  # for smoothing

###############################
#     CONFIG / SETTINGS       #
###############################
# Source videos: the user's swing and the reference (pro) swing to compare.
USER_VIDEO_PATH = 'liu_forehand.mp4'  # the user's own clip
PRO_VIDEO_PATH = 'iwa_forehand.mp4'   # the pro reference clip

# NOTE(review): RESIZE_WIDTH / RESIZE_HEIGHT are not referenced in this file;
# process_video resizes frames to the size reported by the capture instead.
RESIZE_WIDTH = 640
RESIZE_HEIGHT = 360

# When True, process_video draws the wrist path segment by segment.
DRAW_PARTIAL_PATH = True

# Seconds of video between periodically saved annotated frames.
CAPTURE_INTERVAL_SECS = 0.2

# Where the last annotated frame of each video is written.
USER_SCREENSHOT_PATH = "user_screenshot.png"
PRO_SCREENSHOT_PATH = "pro_screenshot.png"

# NOTE(review): these folders are created, but process_video writes its
# periodic captures to the working directory, not into them.
for _capture_dir in ("user_captures", "pro_captures"):
    os.makedirs(_capture_dir, exist_ok=True)
del _capture_dir

# A single MediaPipe Pose estimator (static_image_mode=False, i.e. stream
# mode) shared by process_frame for every frame of both videos.
# NOTE(review): because it is shared, tracking state from the end of the user
# video may carry over into the first frames of the pro video - confirm.
mp_pose = mp.solutions.pose
pose = mp_pose.Pose(
    model_complexity=1,
    static_image_mode=False,
    enable_segmentation=False,
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5,
)

###############################
#   DTW and Smoothing Utils   #
###############################

def simple_dtw(seq1, seq2):
    """
    Compute the dynamic-time-warping distance between two 1-D sequences.

    The local cost is the absolute difference between samples. Each cell is
    reached from above, from the left, or diagonally. Runs in
    O(len(seq1) * len(seq2)) time and memory.

    Returns:
        The accumulated cost of the cheapest warping path. Two empty
        sequences give 0; exactly one empty sequence gives inf.
    """
    rows, cols = len(seq1), len(seq2)
    # acc[r, c] = cheapest cost of aligning seq1[:r] with seq2[:c]. The inf
    # border forbids paths that skip the beginning of either sequence.
    acc = np.full((rows + 1, cols + 1), np.inf)
    acc[0, 0] = 0

    for r, a in enumerate(seq1, start=1):
        for c, b in enumerate(seq2, start=1):
            best_prev = min(
                acc[r - 1, c],      # step down (advance seq1 only)
                acc[r, c - 1],      # step right (advance seq2 only)
                acc[r - 1, c - 1],  # diagonal (advance both)
            )
            acc[r, c] = abs(a - b) + best_prev
    return acc[rows, cols]

def smooth_speed_curve(speeds, window_length=7, polyorder=2, speed_threshold=550):
    """
    Smooth a per-frame speed curve with a Savitzky–Golay filter.

    Missing samples (None) are first forward-filled with the last valid
    speed (0.0 before the first valid one). The filled curve is then
    smoothed to suppress single-frame glitches. Any sample whose filled
    value is >= ``speed_threshold`` keeps its unsmoothed value, so genuine
    peaks are not flattened.

    If the curve is shorter than ``window_length``, the window shrinks to the
    largest odd length that fits. If that window is not larger than
    ``polyorder``, the filled curve is returned unsmoothed.

    Args:
        speeds: Iterable of speeds, with None for unmeasured frames.
        window_length: Filter window length used when the curve is long enough.
        polyorder: Polynomial order of the filter.
        speed_threshold: Filled values at or above this are kept as-is.

    Returns:
        A list of floats with the same length as ``speeds`` and no None entries.
    """
    filled = []
    last_valid = 0.0
    for s in speeds:
        if s is not None:
            last_valid = s
        filled.append(last_valid)
    filled = np.asarray(filled, dtype=float)

    n = len(filled)
    if n >= window_length:
        window = window_length
    else:
        # Short clip: savgol_filter's default 'interp' mode requires
        # window_length <= len(x), so use the largest odd window that fits.
        window = n if n % 2 == 1 else n - 1
        if window <= polyorder:
            # savgol_filter requires polyorder < window_length; too short to smooth.
            return filled.tolist()

    high_speed_mask = filled >= speed_threshold
    smoothed = savgol_filter(filled, window_length=window, polyorder=polyorder)
    # Restore the raw values of very fast samples so peaks survive smoothing.
    smoothed[high_speed_mask] = filled[high_speed_mask]
    return smoothed.tolist()

###############################
#   Pose + Flow Functions     #
###############################

def process_frame(frame):
    """
    Run the module-level ``pose`` estimator on one BGR frame.

    The frame is converted BGR -> RGB before being passed to MediaPipe.

    Returns:
        The detected landmark list, or None when no pose was found.
    """
    results = pose.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    if not results.pose_landmarks:
        return None
    return results.pose_landmarks.landmark

def track_racket(landmarks, frame_width, frame_height):
    """
    Approximate the racket position with the LEFT_WRIST landmark, in pixels.

    The landmark's x/y are multiplied by the frame width/height and truncated
    toward zero with int().

    NOTE(review): the wrist side is hard-coded. A right-handed player's
    racket hand would be RIGHT_WRIST, so confirm this matches the videos.

    Returns:
        (x, y) integer pixel tuple.
    """
    landmark = landmarks[mp_pose.PoseLandmark.LEFT_WRIST.value]
    return (int(landmark.x * frame_width), int(landmark.y * frame_height))

def measure_optical_flow_speed(prev_gray, curr_gray, wrist_coord):
    """
    Estimate the motion magnitude at ``wrist_coord`` between two grayscale frames.

    Dense Farneback optical flow is computed over the whole frame. The flow
    vector is then read at the wrist pixel.

    Args:
        prev_gray: Previous grayscale frame, or None.
        curr_gray: Current grayscale frame, or None.
        wrist_coord: (x, y) integer pixel position.

    Returns:
        Flow magnitude in pixels per frame, or None if either frame is
        missing or the point lies outside the flow field.
    """
    if prev_gray is None or curr_gray is None:
        return None

    # Positional args: pyr_scale, levels, winsize, iterations, poly_n,
    # poly_sigma, flags.
    flow = cv2.calcOpticalFlowFarneback(
        prev_gray, curr_gray, None,
        0.5, 3, 15, 3, 5, 1.2, 0
    )
    rows, cols = flow.shape[:2]
    px, py = wrist_coord
    if not (0 <= px < cols and 0 <= py < rows):
        return None

    dx, dy = flow[py, px]  # the flow field is indexed (row, col) = (y, x)
    return np.sqrt(dx**2 + dy**2)

def draw_shoulder_markers(frame, landmarks, frame_width, frame_height):
    """
    Draw a filled radius-10 circle on each shoulder.

    The left shoulder is drawn with color (0, 255, 0) and the right with
    (0, 0, 255) (green/red in OpenCV's BGR order). The frame is drawn on in
    place and also returned. When ``landmarks`` is None the frame is
    returned untouched.
    """
    if landmarks is None:
        return frame
    markers = (
        (mp_pose.PoseLandmark.LEFT_SHOULDER, (0, 255, 0)),
        (mp_pose.PoseLandmark.RIGHT_SHOULDER, (0, 0, 255)),
    )
    for part, color in markers:
        point = landmarks[part.value]
        center = (int(point.x * frame_width), int(point.y * frame_height))
        cv2.circle(frame, center, 10, color, -1)
    return frame

def calculate_reference_distance(landmarks, frame_width, frame_height):
    """
    Return the Euclidean pixel distance between the two shoulder landmarks.

    Each shoulder is first converted to integer pixel coordinates (int()
    truncation), the same way the shoulder markers are drawn. process_video
    stores the result as the body-size scale for normalizing swing paths.

    NOTE(review): in a side-on view the shoulders can nearly overlap, which
    would give a very small scale. Confirm the camera angle makes this safe.
    """
    def to_pixels(index):
        lm = landmarks[index]
        return int(lm.x * frame_width), int(lm.y * frame_height)

    lx, ly = to_pixels(mp_pose.PoseLandmark.LEFT_SHOULDER.value)
    rx, ry = to_pixels(mp_pose.PoseLandmark.RIGHT_SHOULDER.value)
    return np.linalg.norm(np.array([lx - rx, ly - ry]))

###############################
#   Metrics: Mean Vel, Accel  #
###############################

def compute_mean_velocity(speeds):
    """
    Return the arithmetic mean of the non-None entries of ``speeds``.

    Returns 0 when there are no valid entries.
    """
    total = 0
    count = 0
    for value in speeds:
        if value is None:
            continue
        total += value
        count += 1
    return total / count if count else 0

def compute_mean_acceleration(speeds, fps=None):
    """
    Return the mean absolute change in speed between consecutive samples.

    Only adjacent pairs where both samples are present (not None) are used.
    A run of missing frames therefore never yields a difference that spans
    several frames.

    Args:
        speeds: Per-frame speeds in px/s, with None for unmeasured frames.
        fps: Optional frame rate. When given, the per-frame mean is multiplied
            by it to yield px/s^2. When None (default), the result is in px/s
            per frame.

    Returns:
        The mean absolute difference, or 0 when no adjacent valid pair exists.
    """
    samples = list(speeds)
    diffs = [
        abs(later - earlier)
        for earlier, later in zip(samples, samples[1:])
        if earlier is not None and later is not None
    ]
    if not diffs:
        return 0
    mean_diff = sum(diffs) / len(diffs)
    return mean_diff * fps if fps is not None else mean_diff


def analyze_swing(swing_points):
    """
    Summarize a swing by the straight-line vector from its first to its last point.

    None entries (frames without a detection) are skipped. The start and end
    are therefore the first and last valid points.

    Returns:
        (distance, angle_degrees), or None if fewer than two valid points
        are available.
        - distance: Euclidean length of the vector, in the input's units.
        - angle_degrees: measured counter-clockwise from +x, with "up" in
          the image as positive y. Negative angles are shifted up by 360
          so the result is non-negative.
    """
    valid = [pt for pt in swing_points if pt is not None]
    if len(valid) < 2:
        return None  # need at least two valid points

    start_point, end_point = valid[0], valid[-1]

    dx = end_point[0] - start_point[0]
    # Image y grows downward; flip the sign so an upward swing gives positive dy.
    dy = start_point[1] - end_point[1]

    distance = np.sqrt(dx**2 + dy**2)

    # arctan2 returns values in (-pi, pi]; shift negatives into [0, 2*pi).
    angle_radians = np.arctan2(dy, dx)
    if angle_radians < 0:
        angle_radians += 2.0 * np.pi

    return (distance, np.degrees(angle_radians))


def compare_swings(user_swing, pro_swing):
    """
    Compare two (distance, angle_degrees) swing summaries.

    Returns:
        (dist_diff, angle_diff):
        - dist_diff: absolute difference in distance.
        - angle_diff: smallest angular separation in degrees, in [0, 180].
          Angles wrap around, so 359 deg vs 1 deg differ by 2 deg, not 358.
    """
    user_distance, user_angle = user_swing
    pro_distance, pro_angle = pro_swing
    dist_diff = abs(user_distance - pro_distance)
    raw = abs(user_angle - pro_angle) % 360.0
    angle_diff = min(raw, 360.0 - raw)
    return dist_diff, angle_diff

def normalize_swing_points(swing_points, reference_distance):
    """
    Divide every (x, y) point by ``reference_distance`` to make it dimensionless.

    None entries are kept as None. If ``reference_distance`` is None or 0, the
    input list itself is returned unchanged.
    """
    if reference_distance is None or reference_distance == 0:
        return swing_points
    scale = reference_distance
    return [
        None if pt is None else (pt[0] / scale, pt[1] / scale)
        for pt in swing_points
    ]

def align_trajectories(user_points, pro_points):
    """
    Translate the pro trajectory so its first valid point meets the user's first valid point.

    None entries are skipped when choosing the anchor points and are kept as
    None in the output. If either list is empty or has no valid point, both
    are returned unchanged.

    Returns:
        (user_points, aligned_pro_points). ``user_points`` is returned as-is.
    """
    if not user_points or not pro_points:
        return user_points, pro_points

    user_start = next((p for p in user_points if p is not None), None)
    pro_start = next((p for p in pro_points if p is not None), None)
    if user_start is None or pro_start is None:
        return user_points, pro_points

    offset_x = user_start[0] - pro_start[0]
    offset_y = user_start[1] - pro_start[1]
    aligned_pro = [
        None if p is None else (p[0] + offset_x, p[1] + offset_y)
        for p in pro_points
    ]
    return user_points, aligned_pro
###############################
#     Pose+Flow Pipeline      #
###############################

def process_video(video_path, screenshot_path, window_title, skip_frames=0):
    """
    Run pose detection and optical-flow wrist-speed measurement over one video.

    Each annotated frame is shown in an OpenCV window titled ``window_title``.
    About every CAPTURE_INTERVAL_SECS, a frame is also written to
    ``{window_title}_capture_{N}.png`` in the working directory. The last
    annotated frame is saved to ``screenshot_path``. Pressing Esc stops
    playback early.

    Args:
        video_path: Video file to read.
        screenshot_path: Where to save the final annotated frame.
        window_title: OpenCV window name, also used as the capture-file prefix.
        skip_frames: Number of leading frames to discard before processing.

    Returns:
        (swing_points, reference_distance, speeds):
        - swing_points: (x, y) pixel wrist positions, one per frame with a
          detected pose (so it can be shorter than ``speeds``).
        - reference_distance: shoulder-to-shoulder pixel distance from the
          first frame with a detected pose, or None if none was found.
        - speeds: one entry per processed frame. Each is the wrist
          optical-flow speed in px/s, or None when it could not be measured.
        If the video cannot be opened: (None, None, []).
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error opening video: {video_path}")
        return None, None, []

    swing_points = []
    speeds = []
    reference_distance = None

    # Release the capture and close windows even if detection, drawing or
    # file writing raises part-way through the video.
    try:
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            fps = 30  # frame rate not reported; assume 30

        # Discard the first `skip_frames` frames (stop early at end of video).
        for _ in range(skip_frames):
            grabbed, _frame = cap.read()
            if not grabbed:
                break

        path_overlay = None
        final_frame = None
        prev_gray = None
        frame_index = 0
        frames_per_capture = int(fps * CAPTURE_INTERVAL_SECS)

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Force every frame to the reported capture size so the overlay
            # always matches. NOTE(review): RESIZE_WIDTH/RESIZE_HEIGHT are
            # not used here.
            frame = cv2.resize(frame, (frame_width, frame_height))
            curr_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            if path_overlay is None:
                path_overlay = np.zeros_like(frame)

            # Pose detection; the wrist stands in for the racket position.
            landmarks = process_frame(frame)
            wrist = None
            if landmarks:
                frame = draw_shoulder_markers(frame, landmarks, frame_width, frame_height)
                if reference_distance is None:
                    # Body-size scale comes from the first frame with a pose.
                    reference_distance = calculate_reference_distance(landmarks, frame_width, frame_height)
                wrist = track_racket(landmarks, frame_width, frame_height)
                swing_points.append(wrist)

                # Extend the persistent path overlay with the newest segment.
                if DRAW_PARTIAL_PATH and len(swing_points) > 1:
                    cv2.line(path_overlay, swing_points[-2], swing_points[-1], (0, 255, 0), 2)

            final_frame = cv2.addWeighted(frame, 1.0, path_overlay, 1.0, 0)

            # Wrist speed from optical flow between the previous and current frame.
            speed_px_frame = None
            if prev_gray is not None and wrist is not None:
                speed_px_frame = measure_optical_flow_speed(prev_gray, curr_gray, wrist)
            if speed_px_frame is None:
                speeds.append(None)
            else:
                speed_px_sec = speed_px_frame * fps
                speeds.append(speed_px_sec)
                cv2.putText(final_frame, f"Speed: {speed_px_sec:.2f} px/s", (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

            # Frame number counted from the start of the file, including skipped frames.
            actual_frame_number = frame_index + skip_frames
            cv2.putText(final_frame, f"Frame: {actual_frame_number}", (10, 60),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 2)

            cv2.imshow(window_title, final_frame)

            # Periodic snapshot, every `frames_per_capture` frames.
            if frames_per_capture > 0 and actual_frame_number % frames_per_capture == 0:
                filename = f"{window_title}_capture_{actual_frame_number}.png"
                cv2.imwrite(filename, final_frame)
                print(f"Saved frame: {filename}")

            prev_gray = curr_gray
            frame_index += 1
            if cv2.waitKey(10) & 0xFF == 27:  # Esc stops playback early
                break

        if final_frame is not None:
            cv2.imwrite(screenshot_path, final_frame)
            print(f"Screenshot saved to {screenshot_path}")
    finally:
        cap.release()
        cv2.destroyAllWindows()

    # Summarize raw (unsmoothed) speeds.
    valid_speeds = [s for s in speeds if s is not None]
    if valid_speeds:
        avg_speed = sum(valid_speeds) / len(valid_speeds)
        print(f"[{window_title}] RAW Avg Speed: {avg_speed:.2f} px/s")
    else:
        print(f"[{window_title}] No valid speeds measured")

    return swing_points, reference_distance, speeds


###############################
#             MAIN            #
###############################
def _print_summary_table(user_mean_vel, pro_mean_vel, user_mean_acc, pro_mean_acc, dtw_dist):
    """Print the user-vs-pro metric table."""
    print("\n===== Summary Metrics =====")
    print(f"{'Metric':<20} | {'User':<10} | {'Pro':<10}")
    print("-"*50)
    print(f"{'Mean Velocity (px/s)':<20} | {user_mean_vel:<10.2f} | {pro_mean_vel:<10.2f}")
    # compute_mean_acceleration averages differences between consecutive
    # samples, so its value is px/s per frame, not px/s^2.
    print(f"{'Mean Accel (px/s/fr)':<20} | {user_mean_acc:<10.2f} | {pro_mean_acc:<10.2f}")
    print(f"{'DTW Speed Dist':<20} | {dtw_dist:<10.2f} | {'-':<10}")  # DTW is a single value for both


def _plot_speed_curves(user_speeds, pro_speeds):
    """Plot both smoothed speed curves over their common length (no-op if either is empty)."""
    if not (user_speeds and pro_speeds):
        return
    min_len = min(len(user_speeds), len(pro_speeds))
    t = range(min_len)
    plt.figure(figsize=(10, 5))
    plt.plot(t, user_speeds[:min_len], label="User Speed (px/s)", color='blue')
    plt.plot(t, pro_speeds[:min_len], label="Pro Speed (px/s)", color='green')
    plt.title("Hand (Wrist) Speed Over Time (Smoothed, Preserving High-Speed Peaks)")
    plt.xlabel("Sample Index (post-skip)")
    plt.ylabel("Speed (px/s)")
    plt.legend()
    plt.grid(True)
    plt.show()


def _report_swing_paths(user_swing_pts, user_ref, pro_swing_pts, pro_ref):
    """
    Compare the user and pro wrist paths and plot them.

    Each path is divided by its own reference distance. The pro path is then
    shifted so both start at the same point. Start-to-end summaries are
    printed, and both paths are plotted.
    """
    if not (user_swing_pts and pro_swing_pts):
        print("No valid data from user or pro for final path analysis.")
        return

    user_swing_pts = normalize_swing_points(user_swing_pts, user_ref)
    pro_swing_pts = normalize_swing_points(pro_swing_pts, pro_ref)
    user_swing_pts, pro_swing_pts = align_trajectories(user_swing_pts, pro_swing_pts)

    user_analysis = analyze_swing(user_swing_pts)
    if user_analysis:
        dist, ang = user_analysis
        print(f"[User] distance: {dist:.2f}, angle: {ang:.2f}")

    pro_analysis = analyze_swing(pro_swing_pts)
    if pro_analysis:
        dist, ang = pro_analysis
        print(f"[Pro ] distance: {dist:.2f}, angle: {ang:.2f}")

    if user_analysis and pro_analysis:
        dist_diff, ang_diff = compare_swings(user_analysis, pro_analysis)
        print(f"[Diff] dist: {dist_diff:.2f}, angle: {ang_diff:.2f}")

    plt.figure(figsize=(10, 6))
    if user_swing_pts:
        ux = [p[0] for p in user_swing_pts if p]
        uy = [p[1] for p in user_swing_pts if p]
        plt.plot(ux, uy, marker='o', color='blue', label="User Swing")
    if pro_swing_pts:
        px = [p[0] for p in pro_swing_pts if p]
        py = [p[1] for p in pro_swing_pts if p]
        plt.plot(px, py, marker='o', color='green', label="Pro Swing")

    plt.title("Normalized Swing Trajectories (User vs Pro)")
    plt.xlabel("Normalized X")
    plt.ylabel("Normalized Y")
    plt.gca().invert_yaxis()  # image coordinates: y grows downward
    plt.legend()
    plt.grid(True)
    plt.show()


def main():
    """
    Analyze the user and pro videos and report how the two swings compare.

    For each video: run process_video, smooth its speed curve, and compute
    the mean speed and mean speed change. Then compare the two curves with
    DTW, print a summary table, plot the speeds, and analyze the wrist paths.
    """
    user_skip = 20
    pro_skip = 0

    print("Processing user video...")
    user_swing_pts, user_ref, user_speeds_raw = process_video(
        USER_VIDEO_PATH, USER_SCREENSHOT_PATH, "UserVideo",
        skip_frames=user_skip
    )
    user_speeds = smooth_speed_curve(user_speeds_raw, window_length=7, polyorder=2, speed_threshold=550)
    user_mean_vel = compute_mean_velocity(user_speeds)
    user_mean_acc = compute_mean_acceleration(user_speeds)

    print("\nProcessing pro video...")
    pro_swing_pts, pro_ref, pro_speeds_raw = process_video(
        PRO_VIDEO_PATH, PRO_SCREENSHOT_PATH, "ProVideo",
        skip_frames=pro_skip
    )
    pro_speeds = smooth_speed_curve(pro_speeds_raw, window_length=7, polyorder=2, speed_threshold=550)
    pro_mean_vel = compute_mean_velocity(pro_speeds)
    pro_mean_acc = compute_mean_acceleration(pro_speeds)

    # DTW handles sequences of different lengths itself, so compare the full
    # curves; truncating to the shorter length would discard the end of the
    # longer swing. smooth_speed_curve already replaces every None, so no gap
    # filling is needed.
    if user_speeds and pro_speeds:
        dtw_dist = simple_dtw(user_speeds, pro_speeds)
    else:
        dtw_dist = float("nan")  # nothing to compare; avoid a misleading 0 or inf

    _print_summary_table(user_mean_vel, pro_mean_vel, user_mean_acc, pro_mean_acc, dtw_dist)
    _plot_speed_curves(user_speeds, pro_speeds)
    _report_swing_paths(user_swing_pts, user_ref, pro_swing_pts, pro_ref)

# Run the full user-vs-pro analysis only when executed as a script, not on import.
if __name__ == "__main__":
    main()
